# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from atb_llm.models.base.modeling import get_suffix, FlashAttention, MLP, FlashLayer
from atb_llm.utils.layers import (
    TensorParallelColumnLinear,
    TensorEmbedding,
    TensorParallelEmbedding,
    RMSNorm
)


LLAMA_EMBEDDING_PARALLEL_THRESHOLD = 128256  # vocab size of llama3


class LlamaMLP(MLP):
    def __init__(self, prefix, config, weights, **kwargs):
        super().__init__(prefix, config, weights, **kwargs)
        self.load_weights(**kwargs)


class FlashLlamaAttention(FlashAttention):
    def __init__(self, prefix: str, config, weights, **kwargs):
        super().__init__(prefix, config, weights, **kwargs)
        self.load_weights(**kwargs)

    def load_qkv_weights(self, **kwargs):
        if self.config.model_type == "zhinao":
            query_key_value_linear = TensorParallelColumnLinear.load(
                self.config,
                prefix=f"{self.prefix}.qkv_proj",
                weights=self.weights,
                bias=True
            )
            setattr(self, get_suffix(self.pack_name), query_key_value_linear)
        else:
            super().load_qkv_weights(**kwargs)


class FlashLlamaLayer(FlashLayer):
    def __init__(self, layer_id, config, weights, model_prefix="model", **kwargs):
        super().__init__(layer_id, config, weights, model_prefix, **kwargs)
        self.load_weights(**kwargs)

    def load_weights(self, **kwargs):
        self.self_attn = FlashLlamaAttention(
            prefix=f"{self.prefix}.self_attn", config=self.config, weights=self.weights, **kwargs
        )
        self.mlp = LlamaMLP(prefix=f"{self.prefix}.mlp", config=self.config, weights=self.weights, **kwargs)
        super().load_weights(**kwargs)


class FlashLlamaModel(torch.nn.Module):
    def __init__(self, config, weights, model_prefix="model", **kwargs):
        super().__init__()

        self.parallel_embedding = config.vocab_size >= LLAMA_EMBEDDING_PARALLEL_THRESHOLD
        self.quantize = config.quantize

        if self.quantize == 'w8a8sc' or not self.parallel_embedding:
            self.embed_tokens = TensorEmbedding(
                prefix=f"{model_prefix}.embed_tokens", weights=weights
            )
        elif self.parallel_embedding:
            self.embed_tokens = TensorParallelEmbedding(
                prefix=f"{model_prefix}.embed_tokens", weights=weights
            )

        self.layers = nn.ModuleList(
            [
                FlashLlamaLayer(layer_id, config, weights, model_prefix, **kwargs)
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(
            prefix=f"{model_prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )
